from pathlib import Path
from typing import Annotated, Literal, TypeAlias

from pydantic import BaseModel, Field, model_validator

from prime_rl.trainer.config import (
    AdamWConfig,
    CheckpointConfig,
    ConstantSchedulerConfig,
    ModelConfig,
    OptimizerConfigType,
    SchedulerConfigType,
)
from prime_rl.utils.config import LogConfig, WandbMonitorConfig
from prime_rl.utils.pydantic_config import BaseConfig, BaseSettings


class LossConfig(BaseConfig):
    """Base config for loss."""

    ratio_type: Annotated[Literal["token", "sequence"], Field(description="Type of importance ratio to use.")] = "token"

    mask_ratio_high: Annotated[float, Field(ge=0)] = 8.0
    mask_ratio_low: Annotated[float, Field(ge=0)] = 0.125
    sequence_mask_ratio_low: Annotated[
        float,
        Field(
            ge=0,
            description=(
                "If set, masks entire sequences when any generated token has an importance ratio below this value."
            ),
        ),
    ] = 0.0
    kl_tau: Annotated[float, Field(ge=0)] = 0.0
    kl_mask_type: Annotated[Literal["masked", "unmasked", "all"], Field(description="Type of KL mask to use.")] = "all"


class FakeDataLoaderConfig(BaseConfig):
    """Configures a fake data loader sampling random micro batches for debugging."""

    batch_size: Annotated[int, Field(ge=1)] = 2
    seq_len: Annotated[int, Field(ge=1)] = 128


class DataLoaderConfig(BaseConfig):
    """Configures the data loader used for training."""

    fake: Annotated[FakeDataLoaderConfig | None, Field(description="Whether to use a fake data loader.")] = None


class FileSystemWeightBroadcastConfig(BaseModel):
    """Configures the weight broadcast."""

    type: Literal["filesystem"] = "filesystem"
    save_sharded: Annotated[bool, Field(description="Whether to save the weight checkpoint in sharded format.")] = True
    save_format: Annotated[
        Literal["safetensors", "torch"], Field(description="The format to save the weight checkpoint in.")
    ] = "safetensors"


class NCCLWeightBroadcastConfig(BaseModel):
    """Configures the NCCL broadcast."""

    type: Literal["nccl"] = "nccl"
    host: Annotated[str, Field(description="The host to use for the NCCL broadcast.")] = "localhost"
    port: Annotated[int, Field(description="The port to use for the NCCL broadcast.")] = 29501
    timeout: Annotated[int, Field(description="The timeout in seconds to use for the NCCL broadcast.")] = 1200
    # TODO: Should not be configurable, but auto-inferred
    inference_world_size: Annotated[int, Field(description="The number of GPUs used for inference.")] = 1


WeightBroadcastConfigType: TypeAlias = FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig


class RLTrainerConfig(BaseSettings):
    """Configures the RL trainer"""

    # The model configuration
    model: ModelConfig = ModelConfig()

    # The data configuration
    data: DataLoaderConfig = DataLoaderConfig()

    # The loss configuration
    loss: LossConfig = LossConfig()

    # The optimizer configuration
    optim: Annotated[OptimizerConfigType, Field(discriminator="type")] = AdamWConfig()

    # The learning rate scheduler configuration
    scheduler: Annotated[SchedulerConfigType, Field(discriminator="type")] = ConstantSchedulerConfig()

    # The checkpoint configuration
    ckpt: CheckpointConfig | None = None

    weight_broadcast: Annotated[WeightBroadcastConfigType, Field(discriminator="type")] = (
        FileSystemWeightBroadcastConfig()
    )

    # The logging configuration
    log: LogConfig = LogConfig()

    # The wandb configuration
    wandb: WandbMonitorConfig | None = None

    output_dir: Annotated[
        Path,
        Field(
            description="Directory to write outputs to. Will be populated with checkpoints, weights, rollouts and logs as subdirectories. Should be set to a persistent directory with enough disk space. This value should be distinct across experiments running on a single node. See the README for more details."
        ),
    ] = Path("outputs")

    max_steps: Annotated[
        int | None,
        Field(
            description="Maximum number of steps to run training for. If None, will run indefinitely.",
        ),
    ] = None

    max_async_level: Annotated[
        int,
        Field(
            ge=0,
            description="Maximum number of steps that inference can be ahead of training. Determines how 'off-policy' the inference engines can be. Higher values yield better throughput through async execution, but may yield lower powerofrmance. If 0, will be fully synchronous.",
        ),
    ] = 1

    memory_profiler_path: Annotated[Path | None, Field(description="Path to write memory profile to.")] = None

    bench: Annotated[
        bool,
        Field(
            description="Whether to run in benchmark mode. It will automatically set the maximum number of steps to run to 5 and use fake data.",
        ),
    ] = False

    trace_path: Annotated[Path | None, Field(description="Path to write pytorch profiler trace to.")] = None

    dist_timeout_seconds: Annotated[
        int,
        Field(
            description="Timeout in seconds for torch distributed ops. Defaults to 600 seconds.",
        ),
    ] = 600

    @model_validator(mode="after")
    def auto_setup_bench(self):
        if self.bench:
            self.max_steps = 4  # 1 Warmup + 3 Benchmark
            if not self.data.fake:
                self.data.fake = FakeDataLoaderConfig()
            if self.wandb:  # Do not log extras
                self.wandb.log_extras = None
            if self.ckpt:  # Do not checkpoint
                self.ckpt = None
        return self

    @model_validator(mode="after")
    def disable_logging_wandb_samples(self):
        if self.wandb and self.wandb.log_extras:
            self.wandb.log_extras.samples = False
        return self

    @model_validator(mode="after")
    def dont_do_massive_traces(self):
        if self.trace_path:
            if self.max_steps is None:
                raise ValueError("Must specify max_steps when tracing")
            if self.max_steps >= 10:
                raise ValueError(
                    "Tracing more than 10 steps is not recommended as your trace will be massive. Remove this line if you really want to trace more steps."
                )
        return self

    @model_validator(mode="after")
    def validate_lora_adapter_saving(self):
        if self.ckpt and self.ckpt.weights and self.ckpt.weights.save_adapter_separately:
            lora_enabled = self.model and self.model.experimental and self.model.experimental.lora
            if not lora_enabled:
                raise ValueError(
                    "save_adapter_separately=True requires LoRA to be enabled. "
                    "Set model.experimental.lora or disable save_adapter_separately."
                )
        return self

    @model_validator(mode="after")
    def validate_weight_broadcast_type(self):
        if self.weight_broadcast.type == "nccl" and self.max_async_level != 1:
            raise ValueError("NCCL weight broadcast only works with async level 1")
        return self

    @model_validator(mode="after")
    def validate_opt_and_fsdp_offload(self):
        if self.optim.type == "muon" and self.model.fsdp_cpu_offload:
            raise ValueError("Muon optimizer does not support FSDP CPU offload")
        return self
